import tensorflow as tf
from tensorflow.keras import datasets
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly
# https://stackoverflow.com/questions/57658935/save-jupyter-notebook-with-plotly-express-widgets-displaying
# Enable inline rendering of plotly figures in the notebook (offline mode).
plotly.offline.init_notebook_mode()
# CIFAR-10: 50k train / 10k test 32x32 RGB images; labels are ints 0-9
# with shape (N, 1).
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz 170500096/170498071 [==============================] - 365s 2us/step
# https://stackoverflow.com/questions/51235508/in-which-folder-on-pc-windows-10-does-load-data-save-a-dataset-in-keras
! ls ~/.keras/datasets/cifar-10-batches-py
batches.meta data_batch_2 data_batch_4 readme.html data_batch_1 data_batch_3 data_batch_5 test_batch
train_images.shape
(50000, 32, 32, 3)
# `...` (Ellipsis) provides convenient indexing expressivity
image_index = 0
train_images[image_index].shape, train_images[image_index,...].shape
((32, 32, 3), (32, 32, 3))
# np.newaxis provides convenient dimensionality expressivity
# https://stackoverflow.com/questions/17394882/how-can-i-add-new-dimensions-to-a-numpy-array
train_images[np.newaxis,image_index,...].shape
(1, 32, 32, 3)
# https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/images/cnn.ipynb#scrollTo=K3PAELE2eSU9
# CIFAR-10 class names, indexed by label value 0-9.
# fix: this was a local inside plot_cifar10 but is referenced at module
# scope by the call below — hoisted here so that call no longer NameErrors.
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

def plot_cifar10(x, y):
    """Display one CIFAR-10 image.

    x: HxWx3 image array (cmap is ignored by imshow for RGB input).
    y: label text placed on the x-axis.
    """
    plt.subplot(111)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(x, cmap=plt.cm.binary)
    plt.xlabel(y)

plot_cifar10(train_images[image_index], class_names[train_labels[image_index][0]])
# Map channel letters to their array indices: {'r': 0, 'g': 1, 'b': 2}.
rgb = dict(zip("rgb", range(3)))
print(rgb)
image_index = 0
train_images[image_index][..., rgb['r']]
{'r': 0, 'g': 1, 'b': 2}
array([[ 59, 43, 50, ..., 158, 152, 148],
[ 16, 0, 18, ..., 123, 119, 122],
[ 25, 16, 49, ..., 118, 120, 109],
...,
[208, 201, 198, ..., 160, 56, 53],
[180, 173, 186, ..., 184, 97, 83],
[177, 168, 179, ..., 216, 151, 123]], dtype=uint8)
# https://plotly.com/python/3d-scatter-plots/
# https://plotly.com/python/colorscales/
# https://plotly.com/python/creating-and-updating-figures/
# https://community.plotly.com/t/plotly-express-multiple-plots-overlay/31984
def plot_channel(x, h, color, alpha, loc, i=None, j=None):
    """Add one or more 2-D channels to the global `fig` as 3-D point clouds.

    x:     list of 2-D arrays; each value becomes a marker color.
    h:     list of z heights, one per array in x.
    color: plotly colorscale name.
    alpha: marker opacity.
    loc:   (row, col) subplot position in the global `fig`.
    i, j:  per-array (row, col) pixel offsets; default to 0 for each array.

    NOTE: mutates the module-level `fig` (must exist before calling).
    """
    # fix: defaults were the mutable literals [0] shared across calls.
    i = [0] if i is None else i
    j = [0] if j is None else j
    i_grid, j_grid = np.meshgrid(*[range(n) for n in x[0].shape[:2]])
    frames = []
    for x_, h_, i_, j_ in zip(x, h, i, j):
        tmp = pd.DataFrame({'i': i_ + i_grid.ravel(),
                            'j': j_ + j_grid.ravel(),
                            'h': h_ + 0 * i_grid.ravel()})
        tmp['c'] = x_.ravel()
        frames.append(tmp)
    # fix: DataFrame.append was deprecated and removed in pandas 2.0;
    # collect frames and concat once (also avoids quadratic copying).
    df = pd.concat(frames, ignore_index=True)
    # fix: append_trace is deprecated in favor of add_trace(row=..., col=...).
    fig.add_trace(go.Scatter3d(x=df.i, y=df.j, z=df.h, opacity=alpha,
                               mode='markers',
                               marker={'color': df.c, 'colorscale': color}),
                  row=loc[0], col=loc[1])
# a few other helpful posts which ultimately led to my solution above
# https://community.plotly.com/t/specifying-a-color-for-each-point-in-a-3d-scatter-plot/12652
# https://www.reddit.com/r/rstats/comments/g3tulu/is_it_possible_to_vary_alphaopacity_by_group_in_a/
# https://plotly.com/python/3d-scatter-plots/
# .. mode='markers', marker=dict(color='(255,0,0)', size=10) ...
# https://stackoverflow.com/questions/53875880/convert-a-pandas-dataframe-of-rgb-colors-to-hex
# https://stackoverflow.com/questions/46750462/subplot-with-plotly-with-multiple-traces
fig = plotly.subplots.make_subplots(rows=1,cols=2, specs=[[{'type': 'scene'}, {'type': 'scene'}]])
plot_channel([train_images[image_index][..., rgb['r']]/255], [0], 'Reds', 0.33, [1,1])
plot_channel([train_images[image_index][..., rgb['g']]/255], [1], 'Greens', 0.33, [1,1])
plot_channel([train_images[image_index][..., rgb['b']]/255], [2], 'Blues', 0.33, [1,1])
plot_channel([train_images[image_index][..., rgb['r']]/255,
train_images[image_index][..., rgb['g']]/255,
train_images[image_index][..., rgb['b']]/255], [0,1,2], 'Greys', 0.33, [1,2], [0]*3, [0]*3)
fig.show()
number_kernels = 10
kernel_width = 5
# Ten random 5x5 filters with entries uniform in [-0.5, 0.5).
shape = (number_kernels, kernel_width, kernel_width)
kernels = np.random.random_sample(shape) - 0.5
print(kernels.shape)
kernels
(10, 5, 5)
array([[[ 0.39773621, -0.22900528, 0.09881262, 0.39870386,
0.2217557 ],
[ 0.09521832, -0.0603677 , 0.07556154, -0.03685071,
-0.2009564 ],
[-0.23324775, 0.13597168, -0.38572881, 0.17755535,
0.02958676],
[ 0.17361706, -0.28860532, 0.19120053, -0.25548537,
-0.03434457],
[ 0.43822891, -0.17325585, 0.27953384, -0.40592455,
0.22843153]],
[[-0.25219432, 0.32409323, 0.05650003, -0.35506418,
-0.21340797],
[-0.23070086, 0.23789205, -0.113693 , -0.12188067,
-0.28537089],
[ 0.35086779, -0.44085513, 0.18621295, 0.15655813,
-0.08871664],
[ 0.02848524, -0.0364151 , -0.2990787 , -0.46341008,
0.09617642],
[ 0.26849829, -0.36769018, -0.22750486, -0.42586508,
0.35661221]],
[[ 0.41166825, 0.28066959, 0.34855888, 0.15640564,
0.10742895],
[-0.04649512, -0.096581 , 0.05806068, -0.13888157,
0.32628453],
[ 0.12112283, -0.41475131, 0.4523904 , 0.13580666,
-0.22023767],
[-0.16963556, 0.44228029, 0.42031119, 0.41571177,
-0.4084773 ],
[-0.31393954, -0.00194135, 0.19085985, 0.33623604,
-0.33100384]],
[[ 0.4522182 , 0.10008379, 0.13888906, 0.03366733,
0.05187882],
[ 0.49325739, 0.49173976, -0.47655801, -0.04603961,
0.01312172],
[ 0.06674941, 0.12325937, 0.4237374 , 0.22924355,
0.09154674],
[-0.46924761, -0.18808748, -0.26755613, -0.081073 ,
-0.20719813],
[-0.26687857, 0.11433975, -0.11202769, -0.49593877,
-0.34071989]],
[[ 0.35583374, 0.05417797, 0.19429293, 0.04690031,
-0.18076017],
[-0.26390998, 0.40771601, -0.18751396, -0.42328216,
0.4092391 ],
[ 0.20935034, -0.46515889, -0.38332817, -0.45173748,
0.25107204],
[-0.00946868, -0.21123139, -0.09209932, 0.24485565,
-0.36954314],
[ 0.42544599, -0.35006576, 0.08926736, 0.02709029,
-0.36106572]],
[[ 0.1917415 , 0.12923727, -0.09677804, 0.00235128,
0.22352569],
[-0.17820994, -0.11924904, 0.44308445, -0.21832577,
0.05272216],
[-0.04656538, -0.20568385, -0.03680397, -0.36465248,
0.22460203],
[ 0.44451258, 0.29328181, -0.38657691, -0.12423647,
-0.47787287],
[-0.27331273, -0.11125227, 0.42764733, -0.25131675,
-0.23771714]],
[[ 0.03610223, 0.39661193, 0.29929894, 0.09385738,
-0.47940092],
[ 0.10217362, 0.24694502, 0.27857437, 0.15733641,
-0.09451715],
[-0.11832424, 0.45136702, -0.47465776, 0.49306154,
-0.32778601],
[ 0.24639338, -0.46653094, -0.2127155 , -0.08344141,
0.42424004],
[ 0.44881008, 0.23767325, -0.08938394, 0.3618896 ,
-0.05306801]],
[[-0.33074493, -0.34806425, 0.03138761, 0.27801598,
0.39188293],
[-0.11004306, 0.00321699, -0.05659795, -0.11514728,
-0.44331126],
[-0.30439071, 0.256911 , -0.14817768, 0.27302665,
-0.05730085],
[-0.18186033, 0.04368706, -0.18500711, 0.12878156,
-0.17546038],
[-0.36264879, 0.4798238 , -0.10582065, -0.22548848,
0.45551775]],
[[ 0.11273648, -0.1755804 , -0.08959842, 0.47728864,
0.13929403],
[-0.45505838, 0.01442301, 0.49023557, -0.05229527,
-0.42875331],
[ 0.15720984, 0.36420407, -0.00366684, 0.37044311,
0.42731875],
[-0.06769426, -0.27619938, -0.13072469, 0.44154463,
0.34017505],
[ 0.48106994, 0.43965211, 0.095974 , -0.0406033 ,
-0.12272673]],
[[ 0.32695968, -0.0131281 , 0.33555404, 0.08094979,
-0.15762903],
[-0.07571525, 0.30388555, -0.04378966, -0.19016661,
-0.24949091],
[ 0.26619536, -0.19415484, -0.23356821, 0.31933752,
-0.1145632 ],
[ 0.37917397, 0.27573519, 0.4164153 , -0.16388523,
-0.21491698],
[-0.1297217 , -0.30005154, 0.09354702, -0.30788471,
-0.27059385]]])
# Extract the kernel-sized (5x5) patch of one color channel at the
# top-left corner (i, j) of the first training image.
image_index = 0
channel = 0
i,j = 0,0
train_images[image_index,
             i:(i+kernel_width),
             j:(j+kernel_width),
             channel]
array([[ 59, 43, 50, 68, 98],
[ 16, 0, 18, 51, 88],
[ 25, 16, 49, 83, 110],
[ 33, 38, 87, 106, 115],
[ 50, 59, 102, 127, 124]], dtype=uint8)
kernel_index = 0
kernels[kernel_index,:,:]
array([[ 0.39773621, -0.22900528, 0.09881262, 0.39870386, 0.2217557 ],
[ 0.09521832, -0.0603677 , 0.07556154, -0.03685071, -0.2009564 ],
[-0.23324775, 0.13597168, -0.38572881, 0.17755535, 0.02958676],
[ 0.17361706, -0.28860532, 0.19120053, -0.25548537, -0.03434457],
[ 0.43822891, -0.17325585, 0.27953384, -0.40592455, 0.22843153]])
# One convolution step by hand: elementwise product of the 5x5 image patch
# at (i, j) with one kernel, summed to a single scalar response.
kernel_index = 0
image_index = 0
channel = 0
i,j = 0,0#1,1
(\
train_images[image_index,
             i:(i+kernel_width),
             j:(j+kernel_width),
             channel] * kernels[kernel_index,:,:]\
).sum()
21.217480140740886
# Overwrite the last kernel with a hand-set edge detector:
# -1 in the top two rows, 0 in the middle row, +1 in the bottom two rows
# (responds to vertical intensity gradients).
kernels[-1,:2,:]=-1
kernels[-1,2,:]=0
kernels[-1,3:,:]=1
kernels[-1,...]
array([[-1., -1., -1., -1., -1.],
[-1., -1., -1., -1., -1.],
[ 0., 0., 0., 0., 0.],
[ 1., 1., 1., 1., 1.],
[ 1., 1., 1., 1., 1.]])
# Overwrite the second-to-last kernel with the transposed pattern:
# -1 in the left two columns, 0 in the middle column, +1 on the right
# (responds to horizontal intensity gradients).
kernels[-2,:,:2]=-1
kernels[-2,:,2]=0
kernels[-2,:,3:]=1
kernels[-2,...]
array([[-1., -1., 0., 1., 1.],
[-1., -1., 0., 1., 1.],
[-1., -1., 0., 1., 1.],
[-1., -1., 0., 1., 1.],
[-1., -1., 0., 1., 1.]])
# "Valid" convolution by hand (no padding): each output map shrinks by
# kernel_width-1 per spatial dimension, one 28x28 map per kernel.
convolutional_layer_output = np.zeros(list(np.array(train_images[image_index].shape[:2])-kernel_width+1)+[number_kernels])
print(convolutional_layer_output.shape)
for kernel_index in range(number_kernels):
    for i in range(convolutional_layer_output.shape[0]):
        for j in range(convolutional_layer_output.shape[1]):
            convolutional_layer_output[i,j,kernel_index] = \
                (kernels[kernel_index,:,:] *
                 (train_images[image_index,
                               i:(i+kernel_width),
                               j:(j+kernel_width),
                               channel]/255-0.5)).sum() # sums over the 5x5 patch of the single `channel` set above (pixels rescaled to [-0.5, 0.5])
(28, 28, 10)
# https://matplotlib.org/3.3.3/api/_as_gen/matplotlib.pyplot.subplots.html
# 2x3 panel: top row shows the responses of the two hand-set kernels around
# the source image; bottom row shows the kernels themselves.
f, ((ax1, ax2, ax3), (ax4, ax5, ax6)) = plt.subplots(2, 3, figsize=(15,10))
ax1.imshow(convolutional_layer_output[...,-1])             # response of kernels[-1]
ax2.imshow(train_images[image_index], cmap=plt.cm.binary)  # original image
ax3.imshow(convolutional_layer_output[...,-2])             # response of kernels[-2]
ax4.imshow(kernels[-1,...])                                # kernel set above (rows -1/0/+1)
ax5.axis('off')                                            # empty middle cell
ax6.imshow(kernels[-2,...])                                # kernel set above (cols -1/0/+1)
<matplotlib.image.AxesImage at 0x7fdfa9bd5b80>
convolutional_layer_output[...,kernel_index].ravel()/255
array([-5.82288933e-04, -6.84513407e-04, -3.29586169e-04, -9.80415652e-05,
7.54598706e-05, 2.42407881e-04, 2.20212177e-04, 9.05343299e-05,
1.93653263e-04, 3.62672257e-04, 3.01851907e-04, -2.35108119e-05,
-3.98104030e-05, 2.61242758e-04, 1.36847583e-04, -1.01427202e-04,
7.76067629e-05, 2.99479376e-05, -8.55969128e-05, 1.40801783e-04,
1.27111417e-04, 3.74840056e-04, 3.40629536e-04, 4.84385755e-04,
3.39051678e-04, 3.43372774e-04, 4.08811744e-04, 3.45849741e-04,
-1.02763796e-03, -9.24960791e-04, -6.16190950e-04, -3.06215549e-04,
-6.57485518e-05, 1.79163286e-04, -3.15672379e-05, -1.42006038e-04,
-1.78160904e-04, -1.97611497e-04, -9.81322040e-05, -4.41005711e-04,
-2.98662175e-04, -2.17508259e-04, -3.27253580e-04, -2.00138537e-04,
-2.34310147e-04, -2.50147037e-04, -4.83158986e-04, -1.78155870e-04,
-2.84050998e-04, -2.89411104e-04, -2.14442459e-04, -6.66723762e-05,
7.32002914e-05, 1.79036927e-04, -2.12900705e-05, 1.20162329e-04,
-6.40877225e-04, -6.66447560e-04, -3.39276542e-04, -1.63390326e-04,
-3.29669990e-05, -3.69457047e-06, -1.20266423e-04, -5.33727632e-05,
-2.59227022e-04, -1.64651107e-04, 5.83207505e-05, 3.03025257e-05,
-3.06705928e-05, -1.18985954e-04, -3.25145304e-04, -6.50941434e-05,
-3.79980877e-05, 1.42881238e-04, -4.06912601e-05, -5.18797649e-04,
-1.58588356e-04, -4.80541810e-04, -4.40844696e-05, -5.97057013e-05,
1.59776148e-04, 3.06539262e-04, 1.49756457e-04, 3.18581328e-04,
-4.89073093e-04, -5.68534246e-04, -2.55474796e-04, -1.87666444e-04,
-2.19937517e-04, -1.91682149e-04, -1.06180490e-05, -8.92436594e-05,
-2.62611605e-04, 1.28042361e-04, 4.30938762e-04, 3.39526439e-04,
1.45004644e-04, -3.39323310e-04, -5.25558595e-04, -9.13009623e-05,
5.39131172e-04, 9.18654708e-04, 4.97535081e-04, -3.89174726e-04,
1.14471231e-04, -3.99453075e-04, 7.18530246e-05, -9.50467788e-06,
1.17639662e-04, 3.57637804e-04, 4.36098994e-04, 7.10651209e-05,
-2.55325145e-04, -3.62310251e-04, -2.81903935e-04, -1.31152364e-04,
-1.65991939e-04, -1.15473315e-04, -1.06186973e-04, -1.83959378e-04,
-1.18825715e-04, -3.19933275e-05, 3.61033209e-04, 1.89699530e-04,
-8.17875931e-05, -7.86980664e-04, -5.86751576e-04, -4.77324986e-04,
5.23807157e-04, 1.07525572e-03, 9.52379997e-04, 3.20271567e-04,
-1.59756747e-05, -1.20712053e-04, -4.51918731e-04, -3.99412439e-04,
-1.60003845e-04, -5.14835945e-05, 6.36963387e-05, -1.97236882e-04,
-1.86258932e-05, -3.24626989e-06, 8.25172941e-05, 1.07230951e-04,
-1.40408488e-04, -2.03570991e-04, -2.70921115e-04, -3.14024002e-05,
-1.43455984e-04, 9.96511527e-05, 2.64000361e-04, 2.31470853e-06,
-1.42183379e-04, -5.19307276e-04, -5.92571684e-04, -1.10559395e-03,
-3.87587134e-04, 5.04154823e-06, -1.13328723e-04, -2.89943214e-05,
4.07859437e-05, 4.50704143e-05, -4.26820670e-04, -5.17065814e-04,
-4.55794849e-04, -2.91328238e-04, -5.18737728e-04, -5.98982814e-04,
-6.91532974e-05, 1.56046729e-04, 1.37821939e-04, 8.40423661e-05,
7.44970927e-05, -3.09884838e-05, 1.06320202e-04, 2.49016786e-04,
7.01135087e-06, 4.37966717e-04, -1.21983949e-04, -2.85898100e-04,
-3.05667289e-04, -5.80680310e-04, -5.11946278e-04, -1.03713636e-03,
-7.57620125e-04, -3.76778724e-04, -5.47720256e-04, 9.75701015e-05,
2.44618719e-04, 3.46920949e-04, 4.95312407e-04, -6.77910554e-05,
-3.60290055e-04, -1.09399756e-04, -3.19308493e-04, -3.54208908e-04,
-2.67016680e-05, -2.32987044e-05, -2.86698659e-05, 1.49527450e-04,
2.36364220e-04, -1.54707985e-04, 1.19864591e-04, 1.71477856e-04,
-1.02378719e-04, 1.03927112e-04, -4.50099486e-04, -3.55893389e-04,
-2.46635199e-04, 2.06969128e-04, 5.32891562e-04, -6.45494873e-04,
-2.40997518e-05, -3.22500318e-05, 8.96171018e-05, -7.56450760e-05,
2.45567631e-04, 6.11085708e-04, 7.12411250e-04, 4.80068441e-05,
-3.51203112e-04, -1.21003614e-04, -6.04570055e-05, -1.97920487e-04,
1.13383034e-05, -2.03175358e-04, -1.88369190e-05, 2.24218099e-04,
9.02644705e-05, -1.70106999e-04, 8.65669429e-05, -1.28030551e-04,
-7.78906775e-04, -4.55850382e-04, -6.03263535e-04, -2.50278687e-04,
-1.12379244e-04, 6.02151458e-04, 7.64310312e-04, -2.89756728e-04,
3.87055470e-04, 6.01499829e-06, 1.89629050e-04, -5.74152184e-04,
-1.90914231e-04, -5.89020870e-05, 1.07873138e-04, -1.14517124e-04,
-2.28713634e-04, -4.90153432e-05, -2.61899553e-05, -6.65474902e-05,
1.34884942e-04, -2.15942333e-05, 2.24612427e-04, 2.53365781e-04,
1.10436183e-04, 2.23978532e-04, -5.28459587e-05, -6.04739911e-04,
-1.20615524e-03, -5.25590145e-04, -8.34294182e-04, -5.87070708e-04,
-2.63181596e-04, -2.80816628e-04, 2.26913123e-04, 6.93631843e-05,
5.80230923e-05, -4.24872594e-05, -2.08174554e-04, -2.34910066e-04,
-4.45423156e-04, -5.39194238e-05, 3.45984717e-04, 4.74353677e-07,
7.83887951e-05, 9.88196107e-05, -1.31036756e-05, 1.24942059e-05,
2.25899243e-04, 1.18933159e-04, 5.36150548e-04, 4.04912614e-04,
4.22690376e-04, 2.76288102e-04, -7.94415074e-05, -6.34107235e-04,
-7.01300138e-04, -4.94340500e-04, -7.90158655e-04, -1.65927655e-04,
-5.97397139e-04, -3.54310793e-04, -3.12786710e-05, -5.79805745e-07,
4.65376235e-05, 1.37068454e-04, 1.99704546e-04, 2.37122789e-04,
-1.10856356e-04, 3.37810063e-04, 7.23816771e-04, 4.71777912e-04,
4.69183333e-04, 2.91942791e-04, 2.00498717e-05, 2.91192708e-05,
-3.85224719e-05, -1.42787705e-04, 5.22893352e-04, 1.73231852e-04,
3.51050314e-04, 1.84420577e-04, -3.43727481e-04, -9.37246800e-04,
2.31463139e-04, -2.25697776e-04, -4.98641971e-06, 2.81923941e-04,
-8.39209527e-04, 7.27577306e-05, -2.36841366e-04, 2.03069221e-04,
5.54419057e-04, 4.65680906e-04, 1.05788509e-03, 5.00918039e-04,
3.08697975e-05, -1.28761736e-05, 8.30533885e-04, 6.11775410e-04,
5.13833654e-04, 3.78578030e-04, 4.34730896e-05, -2.01439196e-05,
1.53854969e-04, 1.19171782e-04, 4.11582323e-04, 7.62815179e-05,
4.48240263e-04, -2.33808003e-04, -8.77491114e-04, -5.05906730e-04,
7.46499448e-04, 9.26473713e-05, 7.74729293e-04, 5.49682131e-04,
2.92707775e-04, 4.02148333e-04, 5.85900197e-04, 5.81418042e-04,
4.48401545e-04, 6.46047416e-04, 9.69731615e-04, 6.70967371e-05,
-1.16519567e-04, -5.80157610e-04, 1.44147810e-04, 3.63207741e-04,
5.64538567e-04, 2.61871805e-04, 1.18921213e-04, 1.39507785e-04,
1.42372132e-04, 6.08437383e-05, -8.02767996e-05, -6.70146804e-05,
2.94357410e-04, -5.31789834e-04, -8.62716971e-04, 2.86845591e-04,
5.41490814e-04, 5.91896641e-04, 5.00655178e-04, 8.46579385e-04,
9.30474164e-04, 5.02691464e-04, 9.59971975e-04, 7.00735032e-04,
7.96815687e-04, 6.14524248e-04, 6.38926029e-04, 4.51836987e-04,
1.55731486e-04, -1.38753425e-04, -7.53111011e-05, 3.49126221e-04,
6.19040539e-04, 2.59410318e-04, 2.53308221e-04, 1.76630168e-04,
2.71610738e-04, -2.16750555e-05, -4.37938152e-04, 3.40549206e-04,
-3.87360647e-05, -5.67882193e-04, -1.03130250e-03, 6.14735803e-04,
1.56785163e-04, 3.50496011e-04, 4.53625774e-04, 8.65275330e-04,
6.16323503e-04, 1.00290793e-03, 8.47509286e-04, 9.23433553e-04,
9.29642163e-04, 6.57510227e-04, 7.72772061e-04, 9.74636892e-04,
6.07673834e-04, 8.71683084e-04, 5.21421279e-04, 5.14040239e-04,
9.99361960e-04, 3.65049303e-04, -7.49832941e-06, -4.23599268e-05,
3.56769665e-04, 5.71276602e-05, -3.77520543e-04, -1.64110719e-04,
-4.86978417e-04, -5.09982763e-04, -3.77385360e-04, 5.10193834e-04,
1.50369616e-04, 1.97057633e-04, 7.61710382e-04, 3.60931025e-04,
6.87248499e-04, 1.49941234e-03, 5.41302714e-04, 7.99759165e-04,
1.18402202e-03, 1.22664208e-03, 9.08690005e-04, 7.33058674e-04,
8.90188695e-04, 6.38812202e-04, 7.30227744e-04, 6.66028739e-04,
9.40476359e-04, 3.20894401e-04, -8.30362070e-05, 9.22497685e-06,
3.59042030e-04, 3.38126552e-05, -2.13690102e-04, -6.78899878e-04,
-9.98430032e-05, -4.76402862e-04, 4.17627677e-04, 3.87602033e-04,
3.19173745e-04, 5.43503939e-04, 7.64093171e-04, 4.65473628e-04,
9.75332060e-04, 1.52789731e-03, 7.36368781e-04, 1.36268748e-03,
1.27623405e-03, 1.00707423e-03, 6.16862002e-04, 3.71825787e-04,
8.49764649e-04, -5.83840436e-05, 2.52761083e-04, 7.49806307e-04,
9.10049014e-04, 4.75958734e-04, -6.56407494e-05, 1.12420917e-04,
4.02094699e-04, 4.84211388e-05, -2.80773063e-04, -3.02623898e-04,
-1.34114682e-04, -6.20432610e-07, 1.11905180e-04, -6.81340041e-05,
4.87850856e-04, 1.14855496e-03, 8.23551600e-04, 8.14384628e-04,
1.43375540e-03, 1.48178686e-03, 1.08137609e-03, 9.58721857e-04,
6.92485083e-04, 7.92388835e-04, 4.21303868e-04, 2.76752365e-04,
5.76997510e-04, 2.05480166e-04, -9.98735257e-05, 3.70613217e-04,
6.19782552e-04, 4.80420699e-04, 1.94171794e-04, 2.36541631e-04,
2.42327177e-04, -2.05959996e-05, -4.69592527e-04, -1.57226686e-05,
-4.02510037e-05, 5.59408376e-04, 1.31834135e-04, -1.24414480e-04,
7.00706419e-04, 1.43191358e-03, 1.09588027e-03, 7.24416132e-04,
1.22786159e-03, 8.53452216e-04, 4.30245190e-04, 2.35126187e-04,
4.68821620e-04, 5.61343196e-04, 7.42188696e-04, 7.09534383e-04,
7.78688722e-04, 1.04325225e-03, 3.45391792e-04, 3.79414751e-04,
6.88908626e-04, 3.66880991e-04, 3.18536261e-04, 9.11800514e-05,
1.95198780e-04, -1.45244346e-05, -7.96926904e-06, 3.45475727e-04,
1.12982804e-03, 8.56423805e-04, 7.03008384e-04, 5.32564641e-04,
6.05129245e-04, 1.16448429e-03, 3.35210641e-04, 2.29106998e-04,
5.16659356e-04, 7.02126969e-04, 1.64797729e-04, 2.07009515e-04,
6.49478346e-06, 1.43660819e-04, 8.29198746e-04, 1.04536997e-03,
8.66779942e-04, 1.01833508e-03, 6.11389706e-04, 6.78143783e-05,
2.90308839e-04, 2.52527980e-05, 2.20342801e-05, -2.98274447e-04,
1.29807843e-04, -9.29380588e-05, 2.78882087e-04, 7.36601033e-04,
1.03609516e-03, 5.60851377e-04, 5.36924390e-04, 9.13451578e-04,
2.77120087e-04, 5.12523002e-04, 1.11862348e-04, 1.50090976e-04,
4.61069390e-04, 9.63188219e-04, 6.46237140e-04, 3.32351189e-04,
-6.32634548e-06, 1.24894800e-04, 5.28691202e-04, 8.94814398e-04,
8.36924094e-04, 7.34998880e-04, 1.58870808e-04, -3.11404645e-04,
-1.33097018e-05, 2.92248547e-04, -3.09013282e-05, -5.39977491e-04,
2.98177246e-04, -3.23406893e-05, 9.18559889e-05, 1.89121590e-04,
1.27448471e-05, 4.66621102e-04, 8.54787671e-05, 2.86657906e-04,
3.76793945e-04, 4.25551807e-04, 6.02639473e-04, -4.68952457e-04,
4.27077628e-04, 4.37399163e-04, 3.17049027e-04, 3.82001114e-04,
1.51168091e-04, 8.20095584e-05, 9.71677003e-05, 2.08857462e-04,
2.96505730e-04, 3.43519425e-04, 3.14516570e-04, 2.30561803e-04,
1.01393428e-04, 3.43913581e-04, -9.67517802e-05, -3.14542537e-04,
5.05492246e-04, 4.40357054e-04, 2.01088797e-04, -9.42350260e-05,
-1.41222308e-04, 2.77106010e-04, 1.07345148e-04, 1.40673944e-04,
8.20761414e-04, 1.98704168e-04, 1.97552895e-04, 5.91921640e-05,
1.51226584e-04, 3.53120208e-04, 3.51340810e-05, 1.40092470e-04,
2.08662834e-04, 3.02455478e-04, 2.89354601e-04, 2.73143843e-04,
3.93501395e-04, 3.38158306e-04, 3.83173675e-04, 3.16810585e-04,
2.73482095e-04, -2.75281722e-05, -3.98376159e-04, -1.11148936e-04,
5.92092273e-04, 6.86728015e-04, 4.29393104e-04, 2.78687814e-04,
1.90228073e-04, 1.39959930e-04, 2.88350786e-04, 3.72412747e-04,
5.90657453e-04, 6.54407362e-04, 1.02543271e-04, 3.37669314e-04,
2.47827347e-04, 6.02545966e-04, 1.53357767e-04, 3.05806085e-05,
1.52906901e-04, 1.71972944e-04, 4.06660304e-04, 4.65323386e-05,
-1.88190079e-04, -3.74674667e-04, -2.54540680e-04, -5.79770081e-05,
-1.08049650e-04, -6.23061430e-04, -4.93178059e-04, 1.42391835e-04,
5.40413658e-04, 5.21301634e-04, 5.47768892e-04, 5.48178205e-04,
4.87377366e-04, 4.03951821e-04, 4.60435809e-04, 7.26467340e-04,
7.78278217e-04, 1.20289060e-03, 6.19007402e-04, 5.14044253e-04,
8.63994193e-04, 4.18160930e-04, 3.20084944e-04, 3.05974332e-04,
2.56058464e-04, -1.22791105e-05, -1.41145051e-04, -4.18995479e-04,
-3.33741673e-04, -4.22183599e-04, -5.12287225e-04, -1.92251501e-04,
-1.12719866e-04, -6.03090153e-04, -3.28985344e-04, 2.09213606e-04,
3.97270928e-04, 2.75406709e-04, 4.09145899e-04, 5.27667692e-04,
5.18433934e-04, 5.90887006e-04, 6.60390741e-04, 5.28826480e-04,
3.93974645e-04, 5.64822089e-04, 4.56281537e-04, 3.54463383e-04,
6.40181496e-04, 3.77022319e-04, 8.15165534e-05, 1.90257747e-04,
2.19623014e-05, -1.45490565e-04, 6.27731219e-06, 1.58211421e-04,
1.83902639e-04, 1.80587681e-05, -2.50594453e-04, -8.06151503e-05,
3.18418134e-04, -2.44315754e-04, -9.39044237e-05, 1.82448945e-04,
1.01033296e-04, 1.21548589e-04, 1.52620100e-04, 3.04848777e-04,
4.03837253e-04, 5.04966097e-04, 5.31355102e-04, 2.47305631e-04,
2.30835927e-04, 3.33359248e-04, 2.66858021e-04, 1.15188725e-04,
5.36196053e-05, 2.58301796e-04, 2.65913332e-04, 1.87662372e-04,
-2.82239961e-05, 5.59184292e-05, 2.62509758e-04, 3.66445854e-04,
3.34622246e-04, 1.95438624e-04, 9.13673880e-05, 3.16074621e-05,
5.55815020e-04, 3.08369310e-04, 9.11804397e-05, 1.74578158e-05,
3.38337953e-04, 2.27793741e-05, -8.86946399e-05, 7.11466580e-05,
1.41896985e-04, 3.25138708e-04, 4.81012707e-04, 4.23569616e-04,
4.09992214e-04, 4.04462230e-04, 1.90547114e-04, 1.58564512e-04,
1.64529629e-04, 1.29461945e-04, 2.55987972e-04, 2.65833004e-04,
2.04358622e-04, 6.50225077e-05, 7.88774741e-05, 3.74023469e-04,
4.51773059e-04, 2.71046351e-04, 1.90098030e-04, 3.83752591e-04,
6.72610035e-04, 9.98458817e-04, 4.07567219e-04, -9.47633841e-05])
# might try to animate this later:
# https://plotly.com/python/animations/
# https://plotly.com/python/v3/gapminder-example/
# might try to animate this later:
# https://plotly.com/python/animations/
# https://plotly.com/python/v3/gapminder-example/
# Left scene: the RGB image as three stacked point clouds; right scene: all
# ten convolution feature maps stacked at heights 3*k.
fig = plotly.subplots.make_subplots(rows=1,cols=2, specs=[[{'type': 'scene'}, {'type': 'scene'}]])
plot_channel([train_images[image_index][..., rgb['r']]/255,
              train_images[image_index][..., rgb['g']]/255,
              train_images[image_index][..., rgb['b']]/255], [0,1,2], 'Greys', 0.33, [1,1], [0]*3, [0]*3)
for kern_indx in range(number_kernels):
    plot_channel([convolutional_layer_output[...,kern_indx]/255], [3*kern_indx],
                 px.colors.named_colorscales()[kern_indx], 0.1, [1,2])
kernel_index=2
i,j=17,17
# Highlight kernel `kernel_index`'s 5x5 footprint at image offset (i, j)
# across all three color layers of the left scene...
plot_channel([kernels[kernel_index,...]/255]*3, list(range(3)),
             px.colors.named_colorscales()[kernel_index], 1, [1,1], i=[i]*3, j=[j]*3)
# ...and mark the single corresponding output location in the right scene.
plot_channel([np.array([[0],[1],[0]]).dot(np.array([[0,1,0]]))], [3*kernel_index],
             px.colors.named_colorscales()[kernel_index], 1, [1,2], i=[i], j=[j])
fig.show()
# https://stats.stackexchange.com/questions/335332/why-use-matrix-transpose-in-gradient-descent
pooling_width = 4
# Max-pooling by hand: each non-overlapping pooling_width x pooling_width
# window of a feature map collapses to its maximum, so the 28x28 maps
# become 7x7 (one map per kernel).
pooling_layer_result = np.zeros(list((np.array(convolutional_layer_output.shape[:2])/pooling_width).astype(int))+[convolutional_layer_output.shape[2]])
print(pooling_layer_result.shape)
# Parallel arrays recording each pooled cell's own (ii, jj) output
# coordinates (used by the plotting code further down).
i_pooling_layer_result = 0*pooling_layer_result
j_pooling_layer_result = 0*pooling_layer_result
for k in range(number_kernels):
    # fix: row loop now steps over shape[0] (was shape[1], which only
    # happened to work because the feature maps are square).
    for ii,i in enumerate(range(0, convolutional_layer_output.shape[0], pooling_width)):
        for jj,j in enumerate(range(0, convolutional_layer_output.shape[1], pooling_width)):
            pooling_layer_result[ii,jj,k] = convolutional_layer_output[i:(i+pooling_width),j:(j+pooling_width),k].max()
            i_pooling_layer_result[ii,jj,k] = ii
            j_pooling_layer_result[ii,jj,k] = jj
(7, 7, 10)
# Build "checkerboard" overlays for two kernels: alternating 4x4 patches of
# the conv output (tmp, on the 28x28 grid) and the matching single pooled
# cells (jmp, on the 7x7 grid), plus parallel lists of their (i, j) offsets.
kern_indx = 8
tmp = [convolutional_layer_output[i:(i+pooling_width),j:(j+pooling_width),kern_indx]/255
       for j in range(0,28,8) for i in range(4,28,8)]
tmp_i = [i for j in range(0,28,8) for i in range(4,28,8)]
tmp_j = [j for j in range(0,28,8) for i in range(4,28,8)]
# The second set of comprehensions covers the offset half of the checkerboard.
tmp += [convolutional_layer_output[i:(i+pooling_width),j:(j+pooling_width),kern_indx]/255
        for i in range(0,28,8) for j in range(4,28,8)]
tmp_i += [i for i in range(0,28,8) for j in range(4,28,8)]
tmp_j += [j for i in range(0,28,8) for j in range(4,28,8)]
jmp = [pooling_layer_result[i:(i+1),j:(j+1),kern_indx]/255
       for j in range(0,7,2) for i in range(1,7,2)]
jmp_i = [i for j in range(0,7,2) for i in range(1,7,2)]
jmp_j = [j for j in range(0,7,2) for i in range(1,7,2)]
jmp += [pooling_layer_result[i:(i+1),j:(j+1),kern_indx]/255
        for i in range(0,7,2) for j in range(1,7,2)]
jmp_i += [i for i in range(0,7,2) for j in range(1,7,2)]
jmp_j += [j for i in range(0,7,2) for j in range(1,7,2)]
# Repeat for a second kernel, wrapping everything one list level deeper so
# the plotting loop below can index by k in enumerate([8, 2]).
kern_indx = 2
tmp = [tmp]
tmp_i = [tmp_i]
tmp_j = [tmp_j]
tmp.append([convolutional_layer_output[i:(i+pooling_width),j:(j+pooling_width),kern_indx]/255
            for j in range(0,28,8) for i in range(4,28,8)])
tmp_i.append([i for j in range(0,28,8) for i in range(4,28,8)])
tmp_j.append([j for j in range(0,28,8) for i in range(4,28,8)])
tmp[-1] += [convolutional_layer_output[i:(i+pooling_width),j:(j+pooling_width),kern_indx]/255
            for i in range(0,28,8) for j in range(4,28,8)]
tmp_i[-1] += [i for i in range(0,28,8) for j in range(4,28,8)]
tmp_j[-1] += [j for i in range(0,28,8) for j in range(4,28,8)]
jmp = [jmp]
jmp_i = [jmp_i]
jmp_j = [jmp_j]
jmp.append([pooling_layer_result[i:(i+1),j:(j+1),kern_indx]/255
            for j in range(0,7,2) for i in range(1,7,2)])
jmp_i.append([i for j in range(0,7,2) for i in range(1,7,2)])
jmp_j.append([j for j in range(0,7,2) for i in range(1,7,2)])
jmp[-1] += [pooling_layer_result[i:(i+1),j:(j+1),kern_indx]/255
            for i in range(0,7,2) for j in range(1,7,2)]
jmp_i[-1] += [i for i in range(0,7,2) for j in range(1,7,2)]
jmp_j[-1] += [j for i in range(0,7,2) for j in range(1,7,2)]
# Left scene: all conv feature maps (faint). Right scene: pooled maps (faint).
# Then overlay the checkerboard patches (tmp/jmp built above) for kernels
# 8 and 2 at full opacity to show which windows fed which pooled cells.
fig = plotly.subplots.make_subplots(rows=1,cols=2, specs=[[{'type': 'scene'}, {'type': 'scene'}]])
for kern_indx in range(number_kernels):
    plot_channel([convolutional_layer_output[...,kern_indx]/255], [3*kern_indx],
                 'Greys', 0.1, [1,1])
for kern_indx in range(number_kernels):
    plot_channel([pooling_layer_result[...,kern_indx]/255], [3*kern_indx],
                 'Greys', 0.2, [1,2])
for k,kern_indx in enumerate([8,2]):
    plot_channel(tmp[k], [3*kern_indx]*len(tmp[k]), px.colors.named_colorscales()[kern_indx],
                 1, [1,1], i=tmp_i[k], j=tmp_j[k])
    plot_channel(jmp[k], [3*kern_indx]*len(jmp[k]), px.colors.named_colorscales()[kern_indx],
                 1, [1,2], i=jmp_i[k], j=jmp_j[k])
fig.show()
# padding='same'?
# elu to relu — but this might not be the best choice
import numpy as np
# Standardize images using one global scalar mean/std computed over every
# pixel of every channel of the TRAIN set (axis=(0,1,2,3) reduces all axes);
# the same statistics are applied to the test set. 1e-7 guards against a
# zero std.
mean = np.mean(train_images,axis=(0,1,2,3))
std = np.std(train_images,axis=(0,1,2,3))
train_images = (train_images-mean)/(std+1e-7)
test_images = (test_images-mean)/(std+1e-7)
# Data Augmentation
# https://keras.io/api/preprocessing/image/
from keras.preprocessing.image import ImageDataGenerator
datagen = ImageDataGenerator(rotation_range=15,      # random rotations up to +/-15 degrees
                             width_shift_range=0.1,  # horizontal shifts up to 10% of width
                             height_shift_range=0.1, # vertical shifts up to 10% of height
                             horizontal_flip=True)   # random left-right mirroring
# fit() is only needed for featurewise statistics; harmless for the
# transforms configured above.
datagen.fit(train_images)
# https://appliedmachinelearning.blog/2018/03/24/achieving-90-accuracy-in-object-recognition-task-on-cifar-10-dataset-with-keras-convolutional-neural-networks/
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten, Dropout, BatchNormalization
from keras.layers import Conv2D, MaxPooling2D
from keras import regularizers
# L2 penalty applied to every conv kernel below.
weight_decay = 1e-4
model = Sequential()
# Layer 1: Initial Convolution (BatchNorm before the nonlinearity)
model.add(Conv2D(32, (3,3), padding='same', input_shape=(32, 32, 3),
                 kernel_regularizer=regularizers.l2(weight_decay)))
model.add(BatchNormalization())
model.add(Activation('relu'))
# Layer 2: Convolution on Initial Convolution followed by Pooling with Dropout
model.add(Conv2D(32, (3,3), padding='same',
                 kernel_regularizer=regularizers.l2(weight_decay)))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.2))
# Layer 3: Plain Convolution again
model.add(Conv2D(64, (3,3), padding='same',
                 kernel_regularizer=regularizers.l2(weight_decay)))
model.add(BatchNormalization())
model.add(Activation('relu'))
# fix: removed a stray lone `s` statement here (NameError at import time)
# Layer 4: Convolution on Convolution followed by Pooling with Dropout, again
model.add(Conv2D(64, (3,3), padding='same',
                 kernel_regularizer=regularizers.l2(weight_decay)))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.3))
# Layer 5: another Plain Convolution
model.add(Conv2D(128, (3,3), padding='same',
                 kernel_regularizer=regularizers.l2(weight_decay)))
model.add(BatchNormalization())
model.add(Activation('relu'))
# Layer 6: another Convolution-Pooling-Dropout layer
model.add(Conv2D(128, (3,3), padding='same',
                 kernel_regularizer=regularizers.l2(weight_decay)))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.4))
# Layer 7: final classifier head — 10 logits, one per CIFAR-10 class.
model.add(Flatten())
# fix: was `layers.Dense(10)` but `layers` is never imported in this file;
# Dense is imported directly above.
model.add(Dense(10))
from keras.callbacks import LearningRateScheduler
def lr_schedule(epoch):
    """Piecewise-constant learning-rate schedule.

    Returns 5e-4 through epoch 75, 2.5e-4 through epoch 125, 1e-5 after.
    """
    if epoch > 125:
        return 0.00001
    if epoch > 75:
        return 0.00025
    return 0.0005
# NOTE(review): `lr=` and `decay=` are deprecated (removed in newer Keras
# optimizers) — use `learning_rate=` there; confirm the installed version.
# NOTE(review): CategoricalCrossentropy(from_logits=True) expects one-hot
# labels, but the to_categorical conversion appears BELOW this cell in the
# dump — the conversion must run before this fit (notebook cells out of order).
# The LearningRateScheduler callback overrides the optimizer's lr each epoch.
model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=0.001, decay=1e-6),#'adam',
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),#tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
history = model.fit(datagen.flow(train_images, train_labels, batch_size=64),
                    epochs=150, validation_data=(test_images, test_labels),
                    callbacks=[LearningRateScheduler(lr_schedule)])
# SPARSE REPRESENTATION WON'T WORK
# WE NEED TO CONVERT TO CATEGORICAL
# https://keras.io/api/preprocessing/image/
# https://appliedmachinelearning.blog/2018/03/24/achieving-90-accuracy-in-object-recognition-task-on-cifar-10-dataset-with-keras-convolutional-neural-networks/
# NOTE(review): `np_utils` was removed from newer Keras releases; the modern
# spelling is `from tensorflow.keras.utils import to_categorical` — confirm
# against the installed version. Converts (N, 1) int labels to (N, 10) one-hot.
from keras.utils import np_utils
train_labels = np_utils.to_categorical(train_labels, 10)
test_labels = np_utils.to_categorical(test_labels, 10)
zero-padding can keep the network from shrinking with each layer
biases are typically attached according to the natural structures of the convolutional architecture, but they could also be set to be unique at each locality, for example.
invariance:
variable sized inputs: since convolution is "swept" over or across the input, the input shapes don't initially matter — you can still extract features across the input using a kernel; and using a predefined fixed number of pooling partitions, with implied regions determined by relative percentages that rescale for different input sizes, means your pooling function will always produce the same sized output
Cross-correlation (which is actually what is often meant by "convolution" in NN contexts) is a slight variant of this: $(f*k)(t_0,s_0) = \sum_t \sum_s f(t_0+t,s_0+s) k(t, s)$